import argparse
parser = argparse.ArgumentParser(prog="testname")
parser.add_argument('-f', help="这是一个帮助信息")
parser.add_argument('--arch', '-a', metavar='ARCH', default='NestedUNet',
                    choices=ARCH_NAMES,
                    help='model architecture: ' +
                         ' | '.join(ARCH_NAMES) +
                         ' (default: NestedUNet)')  # 指定网络架构
parser.add_argument('--deep_supervision', default=False, type=str2bool)
parser.add_argument('--input_channels', default=3, type=int,
                    help='input channels')
parser.add_argument('--num_classes', default=1, type=int,
                    help='number of classes')
parser.add_argument('--input_w', default=96, type=int,
                    help='image width')
parser.add_argument('--input_h', default=96, type=int,
                    help='image height')
parser.parse_args()

import os

import cv2
import numpy as np
import torch
import torch.utils.data


class Dataset(torch.utils.data.Dataset):
    def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):
        self.img_ids = img_ids
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.img_ext = img_ext
        self.mask_ext = mask_ext
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        # self.img_ids大小是536（即8*67），其中的idx指的应该是536中的id,但不是列表中第idx个，img_id获得的是图片的名字
        img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))
        # img(96,96,3)，获得一张图片，img_ext指的是拓展名（此代码中都是png）
        # cv2.imread读入图像
        # os.path.join将多个路径组合后返回,本代码返回的是图片路径+图片的id和图片的扩展名（也就是图片格式png或者jpg）
        mask = []
        for i in range(self.num_classes):
            mk = cv2.imread(os.path.join(self.mask_dir, str(i), img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[
                ..., None]  # mk是一个（96，96，1）的数组，由（96，96）--->（96，96，1）
            mask.append(mk)
        mask = np.dstack(mask)  # 堆叠一个list,这个list是（96，96，1）的mk
        # 获得mask图像中的灰度图片
        # cv2.IMREAD_GRAYSCALE 以灰度模式加载图片，可以直接写0
        # [..., None] 在最后面追加一个新的维度

        if self.transform is not None:
            augmented = self.transform(image=img, mask=mask)  # 这个包比较方便，能把mask也一并做掉
            # 把image和mask一起放入augmented中。做了一些变化转化，但是不知道是什么转化
            # Augmented 里面有两个字典，分别是transform过的img和mask。img有变化、mask也有变化。
            img = augmented['image']  # 参考https://github.com/albumentations-team/albumentations
            # 把self.transform转换过的image输入给img，作为img的值
            mask = augmented['mask']  # mask没有变化

        img = img.astype('float32') / 255  # 又是一次对image的改变
        img = img.transpose(2, 0, 1)  # (96,96,3)---->(3,96,96)
        mask = mask.astype('float32') / 255  # 把mask变成0或1
        mask = mask.transpose(2, 0, 1)  # (96,96,1)---->(1,96,96)

        return img, mask, {'img_id': img_id}